# ------------------------------------------------------------------------------
# Fits LASSO models on cardiologist's ECG features that predict stent/CABG
# (among tested patients) and testing (among all patients), then scores the
# test set.
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 05_car_features_model.sh {split}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state so any random draws later in the script are reproducible.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(glue)
library(Matrix)
library(doParallel)
library(glmnet)
library(optparse)
library(here)
library(testit)
library(tidyverse)
library(reticulate) # source_python()

u <- modules::use(here::here("lib", "util.R"))
source_python(here::here("lib", "python_util.py"))

# Command Line Args ------------------------------------------------------------
# --split selects the data split; it names the output subdirectories below and
# is presumably interpolated by glue() into the train/test paths read later
# (TODO confirm against lib/filepaths.yml).
arg_config <- list(make_option("--split", type = "character"))
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)
split <- arg_list$split
# Without this check a missing --split is NULL, which file.path() silently
# drops, so outputs would be written to the wrong directory.
if (is.null(split) || !nzchar(split)) {
  stop("Missing required argument --split", call. = FALSE)
}

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))
models_dir <- file.path(paths$modeling$dir, "models", split, "car")
pred_dir <- file.path(paths$modeling$dir, "prediction", split, "car")
# recursive = TRUE also creates missing parents (e.g. models/<split>). Without
# it dir.create() fails -- silently, given showWarnings = FALSE -- and the
# write_rds() calls at the end of the script error out.
dir.create(models_dir, showWarnings = FALSE, recursive = TRUE)
dir.create(pred_dir, showWarnings = FALSE, recursive = TRUE)

# Load Data --------------------------------------------------------------------
message("Loading data...")
# ECG index, reduced to the columns used below: the encounter key (ed_enc_id),
# a waveform-availability flag, and npy_index, which is later used to join
# encounters to rows of the ECG feature table.
ecgs <- readRDS(paths$analysis$ecgs) %>%
  select(ed_enc_id, has_ecg_waveform, npy_index)
# Train/test encounter lists. glue() fills any {placeholders} in the configured
# paths from variables in scope (presumably `split` -- TODO confirm). The ECG
# index is attached via the project join helper (lib/util.R; presumably joins
# on ed_enc_id -- verify), and encounters whose has_ecg_waveform is FALSE or NA
# are dropped by filter().
ids_train <- readRDS(glue(paths$modeling$train)) %>%
  u$safe_left_join(ecgs) %>%
  filter(has_ecg_waveform)
ids_test <- readRDS(glue(paths$modeling$test)) %>%
  u$safe_left_join(ecgs) %>%
  filter(has_ecg_waveform)

# ECG feature table, read by read_pickle_file() from lib/python_util.py
# (loaded via source_python() above).
ecg_feats_raw <- read_pickle_file(paths$features$ecgs)

# Prep Features ----------------------------------------------------------------
message("Preparing feature set for models...")
# The cardiologist's features are the feature-table columns whose names contain
# "has_"; npy_index is kept as the join key.
keep_feats <- names(ecg_feats_raw)[
  grepl("has_", names(ecg_feats_raw), fixed = TRUE)
]
ecg_feats <- ecg_feats_raw %>%
  select(npy_index, all_of(keep_feats))

# Build the model matrix for a set of encounters: one row per row of `ids`, one
# column per feature in `feats` (join keys removed). Used for both train and
# test, so the two matrices always have the same columns. Assumes the join
# helper preserves the row order of `ids` -- TODO confirm in lib/util.R.
build_feature_matrix <- function(ids, feats) {
  x <- ids %>%
    select(ed_enc_id, npy_index) %>%
    u$safe_left_join(feats) %>%
    select(-ed_enc_id, -npy_index) %>%
    as.matrix()
  # Rows of x are later selected by positions computed on `ids` (observation
  # filters, test-set predictions), so the two must line up one-to-one.
  if (nrow(x) != nrow(ids)) {
    stop(
      "Feature join changed the number of rows (", nrow(ids), " -> ",
      nrow(x), "); check for duplicate npy_index values.",
      call. = FALSE
    )
  }
  x
}

x_train <- build_feature_matrix(ids_train, ecg_feats)
x_test <- build_feature_matrix(ids_test, ecg_feats)

# Select Obs -------------------------------------------------------------------
message("Selecting observations...")
# Training cohort: drop encounters flagged by any exclusion criterion. which()
# also drops rows whose combined condition is NA (an NA flag).
keep_obs <- which(
  !ids_train$excl_flag_c_int &
    !ids_train$excl_flag_chronic & !ids_train$excl_flag_death
)
ids <- ids_train[keep_obs, ]
# drop = FALSE keeps x a matrix even if only a single row survives.
x <- x_train[keep_obs, , drop = FALSE]

# Tested sub-cohort (test_010_day is TRUE); used to fit the stent model.
keep_obs_tested <- which(ids$test_010_day)
ids_tested <- ids[keep_obs_tested, ]
x_tested <- x[keep_obs_tested, , drop = FALSE]

# LASSO ------------------------------------------------------------------------
# Fit a cross-validated binomial LASSO (alpha = 1), tuned on AUC, using the
# pre-assigned folds in `foldid`. One worker is registered per fold for
# cv.glmnet's parallel fold fitting; the implicit cluster is released on exit,
# even if cv.glmnet() errors.
fit_cv_lasso <- function(x, y, foldid) {
  registerDoParallel(cores = uniqueN(foldid))
  on.exit(stopImplicitCluster(), add = TRUE)
  cv.glmnet(
    x = x,
    y = y,
    foldid = foldid,
    parallel = TRUE,
    family = "binomial",
    type.measure = "auc",
    alpha = 1
  )
}

# Stent/CABG outcome, fit only on tested encounters.
message("Modeling stent using LASSO...")
tuning_result <- fit_cv_lasso(
  x = x_tested,
  y = ids_tested[["stent_or_cabg_010_day"]],
  foldid = ids_tested$train_fold
)

# Testing outcome, fit on all retained encounters.
message("Modeling test using LASSO...")
tuning_result_test <- fit_cv_lasso(
  x = x,
  y = ids[["test_010_day"]],
  foldid = ids$train_fold
)

# Predictions ------------------------------------------------------------------
message("Predicting in test set...")
# Score a cv.glmnet fit at lambda.min. predict() returns an n x 1 matrix;
# flatten it so each score is stored as a plain numeric column.
score_lasso <- function(fit, newx, type) {
  as.vector(predict(fit, newx = newx, s = "lambda.min", type = type))
}

# Column names: <p|z>__lasso__<outcome>__<training cohort>__car, where p is the
# predicted probability (type = "response") and z the linear predictor
# (type = "link").
predictions <- ids_test %>%
  select(ptid, ed_enc_id, test_010_day, stent_or_cabg_010_day) %>%
  setDT() %>%
  .[, "p__lasso__stent_or_cabg_010_day__tested__car" :=
    score_lasso(tuning_result, x_test, "response")] %>%
  .[, "z__lasso__stent_or_cabg_010_day__tested__car" :=
    score_lasso(tuning_result, x_test, "link")] %>%
  .[, "p__lasso__test_010_day__all__car" :=
    score_lasso(tuning_result_test, x_test, "response")] %>%
  .[, "z__lasso__test_010_day__all__car" :=
    score_lasso(tuning_result_test, x_test, "link")]

# Save -------------------------------------------------------------------------
# Each model file holds the cv.glmnet fit for its own outcome; the predictions
# table holds test-set scores from both models.
model_fp_stent <- file.path(
  models_dir, "lasso__stent_or_cabg_010_day__tested.rds"
)
message("Saving LASSO stent model to ", model_fp_stent, "...")
write_rds(tuning_result, model_fp_stent)

model_fp_test <- file.path(
  models_dir, "lasso__test_010_day__all.rds"
)
message("Saving LASSO test model to ", model_fp_test, "...")
# This file must hold the test_010_day model (tuning_result_test), not the
# stent model.
write_rds(tuning_result_test, model_fp_test)

prediction_fp <- file.path(pred_dir, "scores_test_car.rds")
message("Saving predictions to ", prediction_fp, "...")
write_rds(predictions, prediction_fp)

message("Done.")
